package com.mars.security.core.validate.code;

import com.mars.security.core.properties.SecurityProperties;
import com.mars.security.core.validate.code.image.ImageCode;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.social.connect.web.HttpSessionSessionStrategy;
import org.springframework.social.connect.web.SessionStrategy;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.bind.ServletRequestBindingException;
import org.springframework.web.bind.ServletRequestUtils;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashSet;
import java.util.Set;

/**
 * Servlet filter that checks the SMS verification code on protected URLs.
 *
 * <p>For every request whose URI matches one of the configured Ant-style
 * patterns (plus the built-in mobile login URL), the code submitted in the
 * {@code smsCode} request parameter is compared with the code stored in the
 * session. When the check fails, the configured
 * {@link AuthenticationFailureHandler} is invoked and the filter chain is
 * not continued.
 *
 * @author MARS
 * @date 2018/7/26
 */
public class SmsCodeFilter extends OncePerRequestFilter implements InitializingBean {

    /** Session attribute name under which the generated SMS code is stored. */
    private static final String SESSION_KEY = ValidateCodeProcessor.SESSION_KEY_PREFIX + "SMS";

    /** Name of the request parameter carrying the code typed by the user. */
    private static final String SMS_CODE_PARAMETER = "smsCode";

    /** URL of the mobile login request; always protected by this filter. */
    private static final String MOBILE_LOGIN_URL = "/authentication/mobile";

    /**
     * Handler invoked when code validation fails. Must be injected through
     * {@link #setAuthenticationFailureHandler(AuthenticationFailureHandler)}
     * before the filter processes a protected request.
     */
    private AuthenticationFailureHandler authenticationFailureHandler;

    /** Helper used to read and remove attributes from the session. */
    private SessionStrategy sessionStrategy = new HttpSessionSessionStrategy();

    /** Ant-style URL patterns whose requests require SMS code validation. */
    private Set<String> urls = new HashSet<>();

    /** Security configuration; source of the additional protected URLs. */
    private SecurityProperties securityProperties;

    /** Matcher used to compare request URIs against the URL patterns. */
    private AntPathMatcher antPathMatcher = new AntPathMatcher();

    /**
     * Populates the set of protected URL patterns from the comma-separated
     * SMS URL property and always adds the mobile login URL.
     *
     * <p>A missing or blank property is treated as "no extra URLs". Each
     * entry is trimmed and empty entries are skipped, so values such as
     * {@code "/a, /b"} or {@code "/a,,/b"} produce usable patterns.
     *
     * @throws ServletException if the superclass initialization fails
     */
    @Override
    public void afterPropertiesSet() throws ServletException {
        super.afterPropertiesSet();

        String configuredUrls = securityProperties.getCode().getSms().getUrl();
        // Guard against an unset property: splitting null would yield null
        // and iterating it would throw a NullPointerException.
        if (StringUtils.isNotBlank(configuredUrls)) {
            for (String configuredUrl : StringUtils.split(configuredUrls, ',')) {
                String pattern = StringUtils.trimToNull(configuredUrl);
                if (pattern != null) {
                    urls.add(pattern);
                }
            }
        }

        // The mobile login request is always intercepted.
        urls.add(MOBILE_LOGIN_URL);
    }

    /**
     * Validates the SMS code when the request targets a protected URL, then
     * continues the filter chain.
     *
     * <p>If validation fails with a {@code ValidateCodeException}, the
     * failure handler is invoked and the chain is not continued. Requests
     * that do not match any protected pattern pass straight through.
     *
     * @param request     current HTTP request
     * @param response    current HTTP response
     * @param filterChain remaining filter chain
     * @throws ServletException propagated from validation, the failure handler or the chain
     * @throws IOException      propagated from the failure handler or the chain
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
                                    FilterChain filterChain) throws ServletException, IOException {
        if (requiresValidation(request)) {
            try {
                validate(new ServletWebRequest(request));
            } catch (ValidateCodeException e) {
                authenticationFailureHandler.onAuthenticationFailure(request, response, e);
                return;
            }
        }

        // Not a protected request, or validation succeeded: continue the chain.
        filterChain.doFilter(request, response);
    }

    /**
     * Tells whether the request URI matches any protected URL pattern.
     *
     * @param request current HTTP request
     * @return {@code true} if at least one pattern matches the request URI
     */
    private boolean requiresValidation(HttpServletRequest request) {
        // NOTE(review): getRequestURI() includes the context path, so the
        // patterns only match as written when deployed at the root context
        // - confirm this is intended.
        String requestUri = request.getRequestURI();
        for (String url : urls) {
            if (antPathMatcher.match(url, requestUri)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compares the code submitted with the request against the code stored
     * in the session.
     *
     * <p>The stored code is removed from the session when it has expired
     * and after a successful match. It is left in place when the submitted
     * code is blank or does not match.
     *
     * @param request wrapper around the current HTTP request
     * @throws ServletRequestBindingException declared by the request parameter lookup
     * @throws ValidateCodeException if the submitted code is blank, no code
     *         is stored, the stored code has expired, or the codes differ
     */
    private void validate(ServletWebRequest request) throws ServletRequestBindingException {
        // Code generated earlier and stored in the session.
        ValidateCode codeInSession = (ValidateCode) sessionStrategy.getAttribute(request, SESSION_KEY);
        // Code typed by the user on the page.
        String codeInRequest = ServletRequestUtils.getStringParameter(request.getRequest(), SMS_CODE_PARAMETER);

        if (StringUtils.isBlank(codeInRequest)) {
            throw new ValidateCodeException("验证码不能为空");
        }

        if (codeInSession == null) {
            throw new ValidateCodeException("验证码不存在");
        }

        if (codeInSession.isExpired()) {
            sessionStrategy.removeAttribute(request, SESSION_KEY);
            throw new ValidateCodeException("验证码已过期");
        }

        // NOTE(review): a mismatch leaves the code in the session, so it can
        // be retried until it expires - confirm this is acceptable.
        if (!StringUtils.equals(codeInSession.getCode(), codeInRequest)) {
            throw new ValidateCodeException("验证码不匹配");
        }

        // Successful match: remove the stored code.
        sessionStrategy.removeAttribute(request, SESSION_KEY);
    }

    /**
     * @return the handler invoked when code validation fails
     */
    public AuthenticationFailureHandler getAuthenticationFailureHandler() {
        return authenticationFailureHandler;
    }

    /**
     * @param authenticationFailureHandler handler to invoke when code validation fails
     */
    public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
        this.authenticationFailureHandler = authenticationFailureHandler;
    }

    /**
     * @return the strategy used to access session attributes
     */
    public SessionStrategy getSessionStrategy() {
        return sessionStrategy;
    }

    /**
     * @param sessionStrategy strategy used to access session attributes
     */
    public void setSessionStrategy(SessionStrategy sessionStrategy) {
        this.sessionStrategy = sessionStrategy;
    }

    /**
     * @return the internal set of protected URL patterns (not a copy)
     */
    public Set<String> getUrls() {
        return urls;
    }

    /**
     * @param urls set of protected URL patterns; replaces the current set
     */
    public void setUrls(Set<String> urls) {
        this.urls = urls;
    }

    /**
     * @return the security configuration used by this filter
     */
    public SecurityProperties getSecurityProperties() {
        return securityProperties;
    }

    /**
     * @param securityProperties security configuration supplying the SMS URL property
     */
    public void setSecurityProperties(SecurityProperties securityProperties) {
        this.securityProperties = securityProperties;
    }
}
